-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[mlir][linalg] Migrate Detensorize pass to new dialect conversion driver #152912
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][linalg] Migrate Detensorize pass to new dialect conversion driver #152912
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Matthias Springer (matthias-springer) ChangesThe pass used to access erased operations and block arguments in the type converter. That is no longer supported in the new conversion driver. Full diff: https://github.com/llvm/llvm-project/pull/152912.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 830905495e759..221f95a8d8f33 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -458,6 +458,22 @@ struct LinalgDetensorize
}
};
+ /// A listener that forwards notifyBlockErased and notifyOperationErased to
+ /// the given callbacks.
+ struct CallbackListener : public RewriterBase::Listener {
+ CallbackListener(std::function<void(Operation *op)> onOperationErased,
+ std::function<void(Block *block)> onBlockErased)
+ : onOperationErased(onOperationErased), onBlockErased(onBlockErased) {}
+
+ void notifyBlockErased(Block *block) override { onBlockErased(block); }
+ void notifyOperationErased(Operation *op) override {
+ onOperationErased(op);
+ }
+
+ std::function<void(Operation *op)> onOperationErased;
+ std::function<void(Block *block)> onBlockErased;
+ };
+
void runOnOperation() override {
MLIRContext *context = &getContext();
DetensorizeTypeConverter typeConverter;
@@ -551,8 +567,22 @@ struct LinalgDetensorize
populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
shouldConvertBranchOperand);
- if (failed(
- applyFullConversion(getOperation(), target, std::move(patterns))))
+ ConversionConfig config;
+ auto onOperationErased = [&](Operation *op) {
+ opsToDetensor.erase(op);
+ detensorableBranchOps.erase(op);
+ };
+ auto onBlockErased = [&](Block *block) {
+ for (BlockArgument arg : block->getArguments()) {
+ blockArgsToDetensor.erase(arg);
+ }
+ };
+ CallbackListener listener(onOperationErased, onBlockErased);
+
+ config.listener = &listener;
+ config.allowPatternRollback = false;
+ if (failed(applyFullConversion(getOperation(), target, std::move(patterns),
+ config)))
signalPassFailure();
RewritePatternSet canonPatterns(context);
diff --git a/mlir/test/Dialect/Linalg/detensorize_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
index 74931cb0830bc..5c29b04630cad 100644
--- a/mlir/test/Dialect/Linalg/detensorize_0d.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
@@ -53,10 +53,11 @@ func.func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tenso
}
// CHECK-LABEL: func @detensor_op_sequence
// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
-// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG: %[[arg1_val_1:.*]] = tensor.extract %[[arg1]]
// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
-// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]]
-// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val]], %[[detensored_res]]
+// CHECK-DAG: %[[arg1_val_2:.*]] = tensor.extract %[[arg1]]
+// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val_2]], %[[arg2_val]]
+// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val_1]], %[[detensored_res]]
// CHECK: %[[detensored_res3:.*]] = arith.divf %[[detensored_res]], %[[detensored_res2]]
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
// CHECK: return %[[new_tensor_res]]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, some nit, but I'll let others review and approve.
// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]] | ||
// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val]], %[[detensored_res]] | ||
// CHECK-DAG: %[[arg1_val_2:.*]] = tensor.extract %[[arg1]] | ||
// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val_2]], %[[arg2_val]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CHECK-DAG can come in any order, but the op here specifies them in a particular order, and the two arg1
vals have the same pattern. This may randomly fail.
|
||
config.listener = &listener; | ||
config.allowPatternRollback = false; | ||
if (failed(applyFullConversion(getOperation(), target, std::move(patterns), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the callback doesn't return anything, it won't change the success/failure result, so this should be fine.
09d8ee9
to
2422ce2
Compare
2422ce2
to
5ade1f5
Compare
The pass used to access erased operations and block arguments in the type converter. That is no longer supported in the new conversion driver.
Depends on #151865.